diff --git a/datafusion/execution/src/spill_file.rs b/datafusion/execution/src/spill_file.rs index dea54dd5d2ca8..dca5da23f53e1 100644 --- a/datafusion/execution/src/spill_file.rs +++ b/datafusion/execution/src/spill_file.rs @@ -42,7 +42,7 @@ pub trait SpillFile: Send + Sync { /// Writer for spill file backends. pub trait SpillWriter: std::io::Write + Send { - /// Intended for close/sync/commit operations. + /// Intended for close/sync/commit operations. fn finish(&mut self) -> Result<()>; } diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs index 0d5b15fcd2f32..a435360ca3364 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs @@ -36,6 +36,8 @@ use crate::aggregates::{ /// Marker for raw rows -> partial state aggregation. pub(in crate::aggregates) struct PartialMarker; +/// Marker for partial state -> partial state aggregation. +pub(in crate::aggregates) struct PartialReduceMarker; /// Marker for raw rows -> partial state conversion without aggregation. pub(in crate::aggregates) struct PartialSkipMarker; /// Marker for partial state -> final value aggregation. 
diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/mod.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/mod.rs
index 2bb1d119f0d61..0d2495a1b556c 100644
--- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/mod.rs
@@ -20,9 +20,11 @@ mod common_ordered;
 mod final_table;
 mod ordered_final_table;
 mod ordered_partial_table;
+mod partial_reduce_table;
 mod partial_table;
 
 pub(super) use common::{
-    AggregateHashTable, FinalMarker, PartialMarker, PartialSkipMarker,
+    AggregateHashTable, FinalMarker, PartialMarker, PartialReduceMarker,
+    PartialSkipMarker,
 };
 pub(super) use common_ordered::OrderedAggregateTable;
diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_reduce_table.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_reduce_table.rs
new file mode 100644
index 0000000000000..4d94c559436fb
--- /dev/null
+++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_reduce_table.rs
@@ -0,0 +1,168 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use arrow::datatypes::SchemaRef;
+use arrow::record_batch::RecordBatch;
+use datafusion_common::{Result, internal_err};
+use datafusion_expr::EmitTo;
+
+use crate::aggregates::AggregateExec;
+
+use super::common::{
+    AggregateHashTable, AggregateHashTableBuffer, AggregateHashTableState,
+    MaterializedAggregateOutput, PartialReduceMarker,
+};
+
+/// Methods specific to the aggregate hash table used in the partial-reduce stage.
+impl AggregateHashTable<PartialReduceMarker> {
+    /// Creates the partial-reduce hash table for `partition` of `agg`.
+    ///
+    /// Delegates to `new_with_filters`, passing `None` as the filter for every
+    /// aggregate expression.
+    pub(in crate::aggregates) fn new(
+        agg: &AggregateExec,
+        partition: usize,
+        output_schema: SchemaRef,
+        batch_size: usize,
+    ) -> Result<Self> {
+        Self::new_with_filters(
+            agg,
+            partition,
+            output_schema,
+            batch_size,
+            vec![None; agg.aggr_expr.len()],
+        )
+    }
+
+    /// Emits the next batch of aggregated group keys and aggregate states.
+    ///
+    /// The output batch size is determined by `self.batch_size`.
+    ///
+    /// Returns `Some(batch)` for each emitted batch, `None` when output is
+    /// exhausted, and an internal error if polled in the `Building` state (the
+    /// in-progress table is kept intact in that case).
+    pub(in crate::aggregates) fn next_output_batch(
+        &mut self,
+    ) -> Result<Option<RecordBatch>> {
+        let output_schema = Arc::clone(&self.output_schema);
+        let batch_size = self.batch_size;
+        // Take ownership of the output state. Note `emit_next_materialized_batch`
+        // updates state after it emits a materialized slice.
+        match std::mem::replace(&mut self.state, AggregateHashTableState::Done) {
+            AggregateHashTableState::Outputting(state) => {
+                if state.group_values.is_empty() {
+                    return Ok(None);
+                }
+
+                let output =
+                    self.materialize_partial_reduce_output(state, output_schema)?;
+                Ok(self.emit_next_materialized_batch(output, batch_size))
+            }
+            AggregateHashTableState::OutputtingMaterialized(output) => {
+                Ok(self.emit_next_materialized_batch(output, batch_size))
+            }
+            AggregateHashTableState::Done => Ok(None),
+            // Polling for output while still building is a caller error. Put the
+            // building state back so the placeholder `Done` written above does
+            // not discard the groups accumulated so far.
+            building @ AggregateHashTableState::Building(_) => {
+                self.state = building;
+                internal_err!("next_output_batch must be called in the outputting state")
+            }
+        }
+    }
+
+    /// Emits every group key and accumulator state in one pass and assembles
+    /// them into a single batch with `output_schema`.
+    fn materialize_partial_reduce_output(
+        &self,
+        mut state: AggregateHashTableBuffer,
+        output_schema: SchemaRef,
+    ) -> Result<MaterializedAggregateOutput> {
+        // `state(EmitTo::All)` consumes accumulator state. Emit all groups once,
+        // then slice the materialized batch on subsequent polls.
+        let emit_to_all = EmitTo::All;
+        let timer = self.group_by_metrics.emitting_time.timer();
+        let mut output = state.group_values.emit(emit_to_all)?;
+
+        for acc in state.accumulators.iter_mut() {
+            output.extend(acc.state(emit_to_all)?);
+        }
+        drop(timer);
+
+        let batch = RecordBatch::try_new(output_schema, output)?;
+        debug_assert!(batch.num_rows() > 0);
+        Ok(MaterializedAggregateOutput::new(batch))
+    }
+
+    /// Pulls the next batch out of `output` and records the follow-up state:
+    /// `Done` once `output` is exhausted, `OutputtingMaterialized` otherwise.
+    fn emit_next_materialized_batch(
+        &mut self,
+        mut output: MaterializedAggregateOutput,
+        batch_size: usize,
+    ) -> Option<RecordBatch> {
+        let batch = output.next_batch(batch_size);
+        if output.is_exhausted() {
+            self.state = AggregateHashTableState::Done;
+        } else {
+            self.state = AggregateHashTableState::OutputtingMaterialized(output);
+        }
+        batch
+    }
+
+    /// Merges one input batch into the table.
+    ///
+    /// For each grouping set, interns the group keys and then passes the
+    /// evaluated arguments to every accumulator through `merge_batch`.
+    pub(in crate::aggregates) fn aggregate_batch(
+        &mut self,
+        batch: &RecordBatch,
+    ) -> Result<()> {
+        let evaluated_batch = self.evaluate_batch(batch)?;
+        let state = self.state.building_mut();
+
+        let timer = self.group_by_metrics.aggregation_time.timer();
+        for group_values in &evaluated_batch.grouping_set_args {
+            state
+                .group_values
+                .intern(group_values, &mut state.batch_group_indices)?;
+            let group_indices = &state.batch_group_indices;
+            let total_num_groups = state.group_values.len();
+
+            for (acc, values) in state
+                .accumulators
+                .iter_mut()
+                .zip(evaluated_batch.accumulator_args.iter())
+            {
+                acc.merge_batch(values, group_indices, total_num_groups)?;
+            }
+        }
+        drop(timer);
+
+        Ok(())
+    }
+
+    /// Moves the table into its output phase by calling `start_outputting`.
+    /// This wrapper itself never fails.
+    pub(in crate::aggregates) fn start_output(&mut self) -> Result<()> {
+        self.start_outputting();
+        Ok(())
+    }
+}
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs
index b73253f1e8e50..11446137f3ca1 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -26,6 +26,7 @@ use crate::aggregates::{
     no_grouping::AggregateStream,
     ordered_final_stream::OrderedFinalAggregateStream,
     ordered_partial_stream::OrderedPartialAggregateStream,
+    partial_reduce_stream::PartialReduceHashAggregateStream,
     row_hash::GroupedHashAggregateStream,
     topk_stream::GroupedTopKAggregateStream,
 };
@@ -81,6 +82,7 @@ mod no_grouping;
 pub mod order;
 mod ordered_final_stream;
 mod ordered_partial_stream;
+mod partial_reduce_stream;
 mod row_hash;
 mod skip_partial;
 mod topk;
@@ -531,6 +533,9 @@ enum StreamType {
     /// Partial stage of the hash aggregation
     /// Input output scheme: initial input -> partial state
     PartialHash(PartialHashAggregateStream),
+    /// Partial-reduce stage of the hash aggregation
+    /// Input output scheme: partial state -> partial state
+    PartialReduceHash(PartialReduceHashAggregateStream),
     /// Final stage of the hash aggregation
     /// Input output scheme: partial state -> final result
     FinalHash(FinalHashAggregateStream),
@@ -560,6 +565,7 @@ impl From<StreamType> for SendableRecordBatchStream {
         match stream {
             StreamType::AggregateStream(stream) => Box::pin(stream),
             StreamType::PartialHash(stream) => Box::pin(stream),
+            StreamType::PartialReduceHash(stream) => Box::pin(stream),
StreamType::FinalHash(stream) => Box::pin(stream), StreamType::OrderedPartialAggregate(stream) => Box::pin(stream), StreamType::OrderedFinalAggregate(stream) => Box::pin(stream), @@ -1048,6 +1054,12 @@ impl AggregateExec { )?)); } + if self.should_use_partial_reduce_hash_stream(context) { + return Ok(StreamType::PartialReduceHash( + PartialReduceHashAggregateStream::new(self, context, partition)?, + )); + } + if self.should_use_ordered_final_aggregate_stream(context) { return Ok(StreamType::OrderedFinalAggregate( OrderedFinalAggregateStream::new(self, context, partition)?, @@ -1108,6 +1120,19 @@ impl AggregateExec { && self.group_by.is_single() } + fn should_use_partial_reduce_hash_stream(&self, context: &TaskContext) -> bool { + // TODO: implement memory-limited path and remove this limitation + if matches!(context.memory_pool().memory_limit(), MemoryLimit::Finite(_)) { + return false; + } + + self.mode == AggregateMode::PartialReduce + && self.limit_options.is_none() + && self.input_order_mode == InputOrderMode::Linear + && !self.group_by.is_true_no_grouping() + && self.group_by.is_single() + } + fn should_use_ordered_final_aggregate_stream(&self, context: &TaskContext) -> bool { // TODO: implement memory-limited path and remove this limitation if matches!(context.memory_pool().memory_limit(), MemoryLimit::Finite(_)) { @@ -3481,6 +3506,99 @@ mod tests { Ok(()) } + fn partial_reduce_test_aggregate() -> Result { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::UInt32, false), + Field::new("b", DataType::Float64, false), + ])); + let group_by = + PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]); + let aggregates: Vec> = vec![Arc::new( + AggregateExprBuilder::new(sum_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("SUM(b)") + .build()?, + )]; + + let empty_input = + TestMemoryExec::try_new_exec(&[vec![]], Arc::clone(&schema), None)?; + let partial = AggregateExec::try_new( + 
AggregateMode::Partial, + group_by.clone(), + aggregates.clone(), + vec![None], + empty_input, + Arc::clone(&schema), + )?; + let partial_schema = partial.schema(); + let partial_state_batch = RecordBatch::try_new( + Arc::clone(&partial_schema), + vec![ + Arc::new(UInt32Array::from(vec![1, 2, 1, 3])), + Arc::new(Float64Array::from(vec![10.0, 20.0, 40.0, 30.0])), + ], + )?; + let partial_reduce_input = TestMemoryExec::try_new_exec( + &[vec![partial_state_batch]], + Arc::clone(&partial_schema), + None, + )?; + + AggregateExec::try_new( + AggregateMode::PartialReduce, + group_by, + aggregates, + vec![None], + partial_reduce_input, + partial_schema, + ) + } + + /// For partial-reduce aggregation, ensures `PartialReduceHashAggregateStream` + /// is used when enabled by migration config. + #[tokio::test] + async fn partial_reduce_aggregate_planning() -> Result<()> { + let partial_reduce = partial_reduce_test_aggregate()?; + let task_ctx = Arc::new( + TaskContext::default().with_session_config( + SessionConfig::new() + .set_bool("datafusion.execution.enable_migration_aggregate", true), + ), + ); + + let stream = partial_reduce.execute_typed(0, &task_ctx)?; + assert!(matches!(stream, StreamType::PartialReduceHash(_))); + let stream: SendableRecordBatchStream = stream.into(); + let output = collect(stream).await?; + assert_eq!(output.iter().map(RecordBatch::num_rows).sum::(), 3); + + Ok(()) + } + + /// Spilling behavior is not implemented for partial-reduce stream yet, so fall + /// back to the existing `GroupedHashAggregateStream` + #[tokio::test] + async fn partial_reduce_aggregate_with_memory_limit_planning() -> Result<()> { + let partial_reduce = partial_reduce_test_aggregate()?; + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(1, 1.0) + .build_arc()?; + let task_ctx = + Arc::new( + TaskContext::default() + .with_session_config(SessionConfig::new().set_bool( + "datafusion.execution.enable_migration_aggregate", + true, + )) + .with_runtime(runtime), + ); + + 
let stream = partial_reduce.execute_typed(0, &task_ctx)?; + assert!(matches!(stream, StreamType::GroupedHash(_))); + + Ok(()) + } + /// Ensures for ordered input, `OrderedPartialAggregateStream` is used. #[tokio::test] async fn ordered_partial_aggregate_planning() -> Result<()> { diff --git a/datafusion/physical-plan/src/aggregates/partial_reduce_stream.rs b/datafusion/physical-plan/src/aggregates/partial_reduce_stream.rs new file mode 100644 index 0000000000000..1a4980c89851a --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/partial_reduce_stream.rs @@ -0,0 +1,385 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Partial-reduce hash aggregation stream implementation. +//! +//! This stream is part of the incremental migration from +//! [`crate::aggregates::row_hash::GroupedHashAggregateStream`]. +//! +//! 
See issue for details:
+
+use std::ops::ControlFlow;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+use arrow::datatypes::SchemaRef;
+use arrow::record_batch::RecordBatch;
+use datafusion_common::Result;
+use datafusion_execution::TaskContext;
+use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
+use futures::stream::{Stream, StreamExt};
+
+use super::AggregateExec;
+use super::aggregate_hash_table::{AggregateHashTable, PartialReduceMarker};
+use crate::metrics::{BaselineMetrics, RecordOutput, SpillMetrics};
+use crate::stream::EmptyRecordBatchStream;
+use crate::{InputOrderMode, RecordBatchStream, SendableRecordBatchStream};
+
+/// Hash aggregation can combine multiple partial stages before final
+/// evaluation. This stream implements the partial-reduce stage.
+///
+/// # Example
+///
+/// SELECT k, AVG(v) FROM t GROUP BY k;
+///
+/// ## Plan
+/// AggregateExec(stage=final)
+/// -- RepartitionExec(hash(k))
+/// ---- AggregateExec(stage=partial_reduce)
+/// ------ RepartitionExec(hash(k))
+/// -------- AggregateExec(stage=partial)
+///
+/// Note: the example plan is only intended to demonstrate this stream's semantics;
+/// the default DataFusion SQL planner does not produce plans in this shape.
+///
+/// This stream implements the middle partial-reduce aggregation in the plan above.
+///
+/// The motivation is to reduce shuffling traffic in a distributed setting. See
+///
+///
+/// ## Partial-Reduce Stage Behavior
+/// Input: partial aggregate state rows
+/// Output: merged partial aggregate state rows
+///
+/// This stage is useful for tree-reduce plans. It consumes the same schema as
+/// a final aggregate stage, but emits the same schema as a partial aggregate
+/// stage.
+pub(crate) struct PartialReduceHashAggregateStream {
+    /// Output schema: group columns followed by partial aggregate state columns.
+    schema: SchemaRef,
+
+    /// Input batches containing partial aggregate state rows.
+    input: SendableRecordBatchStream,
+
+    /// Execution metrics shared with the aggregate plan node.
+    baseline_metrics: BaselineMetrics,
+
+    /// Memory reservation for group keys and accumulators.
+    reservation: MemoryReservation,
+
+    /// Tracks the high-level stream lifecycle. The hash table owns the lower-level
+    /// state for emitting output batches.
+    state: Option<PartialReduceHashAggregateState>,
+}
+
+/// States for partial-reduce hash aggregation processing.
+// The typestate pattern mirrors the final stream and keeps the input/output
+// semantics explicit for this mode.
+enum PartialReduceHashAggregateState {
+    ReadingInput {
+        hash_table: AggregateHashTable<PartialReduceMarker>,
+    },
+    ProducingOutput {
+        hash_table: AggregateHashTable<PartialReduceMarker>,
+    },
+    Done,
+}
+
+type PartialReduceHashAggregatePoll = Poll<Option<Result<RecordBatch>>>;
+type PartialReduceHashAggregateStateTransition = ControlFlow<
+    (
+        PartialReduceHashAggregatePoll,
+        PartialReduceHashAggregateState,
+    ),
+    PartialReduceHashAggregateState,
+>;
+
+impl PartialReduceHashAggregateState {
+    fn hash_table(&self) -> &AggregateHashTable<PartialReduceMarker> {
+        match self {
+            Self::ReadingInput { hash_table } | Self::ProducingOutput { hash_table } => {
+                hash_table
+            }
+            Self::Done => unreachable!("Done state does not hold a hash table"),
+        }
+    }
+
+    fn hash_table_mut(&mut self) -> &mut AggregateHashTable<PartialReduceMarker> {
+        match self {
+            Self::ReadingInput { hash_table } | Self::ProducingOutput { hash_table } => {
+                hash_table
+            }
+            Self::Done => unreachable!("Done state does not hold a hash table"),
+        }
+    }
+
+    fn into_hash_table(self) -> AggregateHashTable<PartialReduceMarker> {
+        match self {
+            Self::ReadingInput { hash_table } | Self::ProducingOutput { hash_table } => {
+                hash_table
+            }
+            Self::Done => unreachable!("Done state does not hold a hash table"),
+        }
+    }
+
+    fn into_producing_output(self) -> Self {
+        Self::ProducingOutput {
+            hash_table: self.into_hash_table(),
+        }
+    }
+
+    fn into_done(self) -> Self {
+        Self::Done
+    }
+}
+
+impl PartialReduceHashAggregateStream {
+    pub fn new(
+        agg: &AggregateExec,
+        context: &Arc<TaskContext>,
+        partition: usize,
+    ) -> Result<Self> {
+        debug_assert_eq!(agg.mode, super::AggregateMode::PartialReduce);
+        debug_assert_eq!(agg.input_order_mode, InputOrderMode::Linear);
+
+        let schema = Arc::clone(&agg.schema);
+        let input = agg.input.execute(partition, Arc::clone(context))?;
+        let batch_size = context.session_config().batch_size();
+        let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition);
+
+        // Preserve the existing aggregate metric surface for this plan node.
+        let _spill_metrics = SpillMetrics::new(&agg.metrics, partition);
+
+        let hash_table = AggregateHashTable::<PartialReduceMarker>::new(
+            agg,
+            partition,
+            Arc::clone(&schema),
+            batch_size,
+        )?;
+
+        let reservation =
+            MemoryConsumer::new(format!("PartialReduceHashAggregateStream[{partition}]"))
+                .register(context.memory_pool());
+
+        Ok(Self {
+            schema,
+            input,
+            baseline_metrics,
+            reservation,
+            state: Some(PartialReduceHashAggregateState::ReadingInput { hash_table }),
+        })
+    }
+
+    fn start_output(
+        &mut self,
+        hash_table: &mut AggregateHashTable<PartialReduceMarker>,
+    ) -> Result<()> {
+        let input_schema = self.input.schema();
+        self.input = Box::pin(EmptyRecordBatchStream::new(input_schema));
+        hash_table.start_output()
+    }
+
+    /// Handle ReadingInput state - aggregate partial state batches into the hash table.
+    ///
+    /// See comments at `poll_next()` for details.
+    ///
+    /// Returns the next operator state with control flow decision.
+    fn handle_reading_input(
+        &mut self,
+        cx: &mut Context<'_>,
+        mut original_state: PartialReduceHashAggregateState,
+    ) -> PartialReduceHashAggregateStateTransition {
+        debug_assert!(matches!(
+            &original_state,
+            PartialReduceHashAggregateState::ReadingInput { .. }
+        ));
+        debug_assert!(original_state.hash_table().is_building());
+
+        match self.input.poll_next_unpin(cx) {
+            Poll::Pending => ControlFlow::Break((Poll::Pending, original_state)),
+            // Get a new input batch, aggregate it in the hash table
+            Poll::Ready(Some(Ok(batch))) => {
+                let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
+                let timer = elapsed_compute.timer();
+                let result = original_state.hash_table_mut().aggregate_batch(&batch);
+                timer.done();
+
+                if let Err(e) = result {
+                    return ControlFlow::Break((
+                        Poll::Ready(Some(Err(e))),
+                        original_state,
+                    ));
+                }
+
+                if let Err(e) = self
+                    .reservation
+                    .try_resize(original_state.hash_table().memory_size())
+                {
+                    return ControlFlow::Break((
+                        Poll::Ready(Some(Err(e))),
+                        original_state,
+                    ));
+                }
+
+                ControlFlow::Continue(original_state)
+            }
+            Poll::Ready(Some(Err(e))) => {
+                ControlFlow::Break((Poll::Ready(Some(Err(e))), original_state))
+            }
+            // Input ends, move to output state
+            Poll::Ready(None) => {
+                let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
+                let timer = elapsed_compute.timer();
+                let result = self.start_output(original_state.hash_table_mut());
+                timer.done();
+
+                match result {
+                    Ok(()) => {
+                        ControlFlow::Continue(original_state.into_producing_output())
+                    }
+                    Err(e) => {
+                        ControlFlow::Break((Poll::Ready(Some(Err(e))), original_state))
+                    }
+                }
+            }
+        }
+    }
+
+    /// Handle ProducingOutput state - emit merged partial aggregate state batches.
+    ///
+    /// See comments at `poll_next()` for details.
+    ///
+    /// Returns the next operator state with control flow decision.
+    fn handle_producing_output(
+        &mut self,
+        mut original_state: PartialReduceHashAggregateState,
+    ) -> PartialReduceHashAggregateStateTransition {
+        debug_assert!(matches!(
+            &original_state,
+            PartialReduceHashAggregateState::ProducingOutput { .. }
+        ));
+        debug_assert!(!original_state.hash_table().is_building());
+
+        let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
+        let timer = elapsed_compute.timer();
+        let result = original_state.hash_table_mut().next_output_batch();
+        timer.done();
+
+        match result {
+            Ok(Some(batch)) => {
+                let _ = self
+                    .reservation
+                    .try_resize(original_state.hash_table().memory_size());
+                debug_assert!(batch.num_rows() > 0);
+                let next_state = if original_state.hash_table().is_done() {
+                    original_state.into_done()
+                } else {
+                    original_state
+                };
+
+                ControlFlow::Break((
+                    Poll::Ready(Some(Ok(batch.record_output(&self.baseline_metrics)))),
+                    next_state,
+                ))
+            }
+            Ok(None) => {
+                let _ = self.reservation.try_resize(0);
+                ControlFlow::Continue(original_state.into_done())
+            }
+            Err(e) => ControlFlow::Break((Poll::Ready(Some(Err(e))), original_state)),
+        }
+    }
+}
+
+impl Stream for PartialReduceHashAggregateStream {
+    type Item = Result<RecordBatch>;
+
+    /// Entry point for the partial-reduce hash aggregate state machine.
+    ///
+    /// See comments in [`PartialReduceHashAggregateStream`] for high-level ideas.
+    ///
+    /// State transition graph:
+    ///
+    /// ```text
+    /// (start)
+    ///   -> ReadingInput
+    ///      The stream starts by polling partial-state input and merging those
+    ///      states into the partial-reduce hash table.
+    ///
+    /// ReadingInput
+    ///   -> ReadingInput
+    ///      Aggregate one partial-state input batch, update the inner aggregate
+    ///      hash table, and continue with the next input batch.
+    ///
+    ///   -> ProducingOutput
+    ///      Input was exhausted. Move to the next state to start outputting
+    ///      merged partial aggregate states.
+    ///
+    /// ProducingOutput
+    ///   -> ProducingOutput
+    ///      One merged partial-state output batch was yielded; repeat to
+    ///      continue producing output incrementally.
+    ///
+    ///   -> Done
+    ///      All merged partial-state output was emitted.
+    ///
+    /// Done
+    ///   -> (end)
+    /// ```
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        loop {
+            let cur_state = self
+                .state
+                .take()
+                .expect("PartialReduceHashAggregateStream state should not be None");
+
+            let next_state = match cur_state {
+                state @ PartialReduceHashAggregateState::ReadingInput { .. } => {
+                    self.handle_reading_input(cx, state)
+                }
+                state @ PartialReduceHashAggregateState::ProducingOutput { .. } => {
+                    self.handle_producing_output(state)
+                }
+                state @ PartialReduceHashAggregateState::Done => {
+                    let _ = self.reservation.try_resize(0);
+                    self.state = Some(state);
+                    return Poll::Ready(None);
+                }
+            };
+
+            match next_state {
+                ControlFlow::Continue(next_state) => {
+                    self.state = Some(next_state);
+                    continue;
+                }
+                ControlFlow::Break((poll, next_state)) => {
+                    self.state = Some(next_state);
+                    return poll;
+                }
+            }
+        }
+    }
+}
+
+impl RecordBatchStream for PartialReduceHashAggregateStream {
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+}